/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package face5wap;

import lombok.EqualsAndHashCode;
import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.impl.UniformDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Gradient-penalty loss for Wasserstein GAN discriminators (WGAN-GP style).
 *
 * Computes, per sample, {@code (1 - ||grad||_2)^2} where the gradient is obtained by
 * back-propagating through the activation function, then averages over the batch.
 * This is not a general-purpose loss function; it is intended as a discriminator
 * (critic) loss. When used in a discriminator, label real samples with 1 and
 * generated samples with -1 (instead of the 1/0 labels of standard GANs).
 *
 * See <a href="https://arxiv.org/abs/1704.00028">Improved Training of Wasserstein GANs</a>
 * (Gulrajani et al.) for the gradient-penalty formulation.
 *
 * @author Ryan Nett
 */
@EqualsAndHashCode(callSuper = false)
public class NewLossGradientPenalty implements ILossFunction {

    /**
     * Computes the per-sample gradient penalty {@code (1 - ||grad||_2)^2}.
     *
     * @param labels       labels (same shape as {@code preOutput}); cast to preOutput's dtype
     * @param preOutput    raw network output, before the activation function
     * @param activationFn activation function of the output layer
     * @param mask         optional mask applied to the per-sample penalties; may be null
     * @return array of per-sample penalty values (batch dimension preserved)
     */
    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        if (!labels.equalShapes(preOutput)) {
            Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
        }
        labels = labels.castTo(preOutput.dataType());   //No-op if already correct dtype

        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        // NOTE(review): backprop(labels, output) passes the labels where the activation's
        // pre-activation input is normally expected — confirm this matches the intended
        // gradient-penalty formulation before relying on it.
        INDArray gradients = activationFn.backprop(labels, output).getFirst();

        // Euclidean norm of the gradients: square ...
        INDArray gradientsSqr = gradients.mul(gradients);
        // ... then sum over every non-batch dimension. Reduce from the LAST dimension
        // down to 1: each sum() lowers the rank, so ascending axis indices (as in the
        // original arange-based loop) become invalid for rank > 2 inputs.
        for (int dim = gradientsSqr.rank() - 1; dim >= 1; dim--) {
            gradientsSqr = gradientsSqr.sum(dim);
        }
        // ... and take the square root to get the per-sample L2 norm.
        INDArray gradientL2Norm = Transforms.sqrt(gradientsSqr);

        // Per-sample penalty (1 - ||grad||)^2.
        // NOTE(review): the usual WGAN-GP lambda weighting factor is not applied here.
        INDArray oneMinusNorm = Nd4j.ones(gradientL2Norm.shape()).sub(gradientL2Norm);
        INDArray gradientPenalty = oneMinusNorm.mul(oneMinusNorm);

        if (mask != null) {
            LossUtil.applyMask(gradientPenalty, mask);
        }
        return gradientPenalty;
    }

    /**
     * Returns a random convex combination {@code alpha * real + (1 - alpha) * fake}
     * with {@code alpha ~ U(0, 1)}, as used for the interpolated samples in WGAN-GP.
     *
     * @param batch expected batch size (first dimension of {@code real}/{@code fake})
     * @param real  batch of real samples
     * @param fake  batch of generated samples, same shape as {@code real}
     * @return element-wise random weighted average of the two batches
     */
    public static INDArray randomWeightedAverage(int batch, INDArray real, INDArray fake) {
        Preconditions.checkArgument(real.shape()[0] == batch,
                "Expected batch size %s but real samples have leading dimension %s", batch, real.shape()[0]);
        // Sample alpha with the same shape as the inputs instead of a hard-coded
        // {batch, 784}, so any sample size works.
        INDArray alpha = Nd4j.rand(new UniformDistribution(0, 1), real.shape());
        // Compute (1 - alpha) BEFORE alpha is consumed: the original chained
        // alpha.muli(real) first, mutating alpha in place, so its "1 - alpha" term
        // actually evaluated to 1 - alpha*real.
        INDArray oneMinusAlpha = Nd4j.ones(alpha.shape()).sub(alpha);
        return alpha.mul(real).addi(oneMinusAlpha.muli(fake));
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
            boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);

        // mean() reduces to a scalar; sumNumber() then just extracts its value.
        double score = scoreArr.mean().sumNumber().doubleValue();

        if (average) {
            // NOTE(review): mean() above has already divided by the element count, so this
            // divides by the batch size a second time — confirm the double averaging is intended.
            score /= scoreArr.size(0);
        }

        return score;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        return Nd4j.expandDims(scoreArr.mean(), 1);
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        if (!labels.equalShapes(preOutput)) {
            Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s", labels.shape(), preOutput.shape());
        }
        labels = labels.castTo(preOutput.dataType());   //No-op if already correct dtype
        INDArray dLda = labels.div(labels.size(1));

        if (mask != null && LossUtil.isPerOutputMasking(dLda, mask)) {
            // Apply per-output masking to the gradient that is actually propagated.
            // The original masked `labels`, which is never read again after dLda is
            // built above, making the masking a no-op.
            LossUtil.applyMask(dLda, mask);
        }

        INDArray grad = activationFn.backprop(preOutput, dLda).getFirst();

        if (mask != null) {
            LossUtil.applyMask(grad, mask);
        }

        return grad;
    }

    /**
     * {@inheritDoc}
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
            INDArray mask, boolean average) {
        return new Pair<>(computeScore(labels, preOutput, activationFn, mask, average),
                computeGradient(labels, preOutput, activationFn, mask));
    }

    @Override
    public String name() {
        return toString();
    }

    @Override
    public String toString() {
        // NOTE(review): this reports "LossWasserstein()" although the class is
        // NewLossGradientPenalty — left unchanged in case serialized configs or logs
        // depend on the string, but confirm and rename if nothing does.
        return "LossWasserstein()";
    }
}
